"""
Phase 1: Data Assembly — Real Exchange Rate Panel
===================================================
Merges full_panel.csv with BIS REER data (from asset_returns pipeline)
and IMF REER data (from clearing_channels). Adds Balassa-Samuelson
controls, terms of trade, and non-tradable share proxies.

Key extension over Paper 6 (asset_returns phase 3): dedicated treatment
with proper BS controls, global sample (not just OECD), and PPP
deviation analysis.

Output: rer/data/processed/rer_panel.csv
Tables: rer/output/tables/summary_statistics.md
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# ── Paths ──────────────────────────────────────────────────────────────────
# Absolute location of this sub-project; every other directory below is
# derived from it (directly, or via its parent folder).
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/rer")
ROOT_DIR = PROJECT_DIR.parent
MULTILATERAL_DIR = ROOT_DIR.joinpath("multilateral")
ASSET_DIR = ROOT_DIR.joinpath("asset_returns")
CLEARING_DIR = MULTILATERAL_DIR.joinpath("clearing_channels")
PROCESSED_DIR = PROJECT_DIR.joinpath("data", "processed")
TABLES_DIR = PROJECT_DIR.joinpath("output", "tables")

# Both output folders must exist before anything is written into them.
for d in (PROCESSED_DIR, TABLES_DIR):
    d.mkdir(parents=True, exist_ok=True)

# Put the multilateral project's src/ folder first on the import search path.
sys.path.insert(0, str(MULTILATERAL_DIR.joinpath("src")))

# Three-letter country codes of the 38-member OECD group; matched against the
# panel's `iso3` column to build the OECD dummy.
OECD_38 = (
    "AUS AUT BEL CAN CHL COL CRI CZE DNK EST "
    "FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN "
    "KOR LVA LTU LUX MEX NLD NZL NOR POL PRT "
    "SVK SVN ESP SWE CHE TUR GBR USA"
).split()

# Eurozone membership is time-varying: each country code maps to the first
# year from which it is flagged as a member.
EMU_MEMBERS = {
    # The eleven countries that share the 1999 start year.
    **dict.fromkeys(
        ("AUT", "BEL", "FIN", "FRA", "DEU", "IRL",
         "ITA", "LUX", "NLD", "PRT", "ESP"),
        1999,
    ),
    # Later entrants, in order of entry year.
    "GRC": 2001,
    "SVN": 2007,
    "CYP": 2008,
    "MLT": 2008,
    "SVK": 2009,
    "EST": 2011,
    "LVA": 2014,
    "LTU": 2015,
}


# Sample window (inclusive). The upper bound is applied when the base panel is
# loaded; both bounds are applied again just before the panel is saved.
_FIRST_YEAR = 1990
_LAST_YEAR = 2024

# Columns treated as demographic regressors: each one present in the panel
# receives interaction, 5-row-lag and first-difference variants.
_Z_VARS = ('Z_1', 'Z_2', 'Z_3')


def _log_with_floor(series: pd.Series, floor: float) -> pd.Series:
    """Return ln(series) after raising values below *floor* to *floor*.

    NaN passes through both ``clip`` and ``np.log`` unchanged, so missing
    inputs stay missing without any extra masking.
    """
    return np.log(series.clip(lower=floor))


def _income_tercile(x: pd.Series) -> pd.Series:
    """Label one cross-section 'low'/'mid'/'high' by tercile; all-NaN if qcut fails."""
    try:
        return pd.qcut(x, 3, labels=['low', 'mid', 'high'], duplicates='drop')
    except (ValueError, IndexError):
        # e.g. too few distinct values to form three bins with three labels
        return pd.Series(np.nan, index=x.index)


def _load_base_panel() -> pd.DataFrame:
    """Step [1]: read full_panel.csv and drop rows after the last sample year."""
    print("\n[1] Loading full_panel.csv ...")
    full_path = MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv"
    df = pd.read_csv(full_path)
    df = df[df['year'] <= _LAST_YEAR].copy()
    print(f"  Full panel: {df['iso3'].nunique()} countries, {len(df):,} obs")
    return df


def _merge_bis_reer(df: pd.DataFrame) -> pd.DataFrame:
    """Step [2]: left-merge BIS REER columns from asset_panel.csv when that file exists."""
    print("\n[2] Loading BIS REER from asset_returns ...")
    asset_path = ASSET_DIR / "data" / "processed" / "asset_panel.csv"
    if not asset_path.exists():
        print("  asset_panel.csv not found — will rely on IMF REER only")
        return df

    asset_df = pd.read_csv(asset_path, usecols=['iso3', 'year', 'reer', 'log_reer', 'd_reer'])
    # One row per (iso3, year) on the right side so the left merge cannot
    # multiply rows of the base panel.
    asset_df = asset_df.dropna(subset=['reer']).drop_duplicates(subset=['iso3', 'year'])
    n_bis = asset_df['iso3'].nunique()
    print(f"  BIS REER: {n_bis} countries, {len(asset_df):,} obs")
    # NOTE(review): if full_panel.csv already carries a 'reer' column, pandas
    # suffixes the clash as reer_x / reer_y and step [4] will not find 'reer'.
    # Confirm the base panel has no such columns.
    return df.merge(asset_df, on=['iso3', 'year'], how='left')


def _merge_imf_reer(df: pd.DataFrame) -> pd.DataFrame:
    """Step [3]: left-merge IMF REER as ``reer_imf`` when the file exists and is parseable."""
    print("\n[3] Loading IMF REER from clearing_channels ...")
    imf_path = CLEARING_DIR / "data" / "raw" / "reer_annual.csv"
    if not imf_path.exists():
        print("  IMF REER file not found — skipping")
        return df

    imf_df = pd.read_csv(imf_path)
    # Standardize column names
    if 'ref_area' in imf_df.columns:
        imf_df = imf_df.rename(columns={'ref_area': 'iso3', 'value': 'reer_imf'})
    elif 'iso3' not in imf_df.columns:
        # Unknown layout: show what was found to help diagnose the skip below.
        print(f"  IMF REER columns: {list(imf_df.columns)}")
    if 'reer' in imf_df.columns and 'reer_imf' not in imf_df.columns:
        imf_df = imf_df.rename(columns={'reer': 'reer_imf'})

    # All three columns are selected below, so all three must be present;
    # a missing 'year' is treated as unparseable rather than raising KeyError.
    required = ['iso3', 'year', 'reer_imf']
    if any(col not in imf_df.columns for col in required):
        print("  Could not parse IMF REER — skipping")
        return df

    imf_df = imf_df[required].dropna().drop_duplicates(subset=['iso3', 'year'])
    n_imf = imf_df['iso3'].nunique()
    print(f"  IMF REER: {n_imf} countries, {len(imf_df):,} obs")
    return df.merge(imf_df, on=['iso3', 'year'], how='left')


def _add_reer_variables(df: pd.DataFrame) -> pd.DataFrame:
    """Step [4]: build the combined REER series and its log, change, MA and demeaned forms."""
    print("\n[4] Constructing unified REER variables ...")

    # Use BIS as primary, fill with IMF where BIS missing
    if 'reer' not in df.columns:
        df['reer'] = np.nan
    if 'reer_imf' in df.columns:
        df['reer_combined'] = df['reer'].fillna(df['reer_imf'])
    else:
        df['reer_combined'] = df['reer']

    # Values below 1 are floored at 1 before taking logs (so they map to 0).
    df['log_reer_combined'] = _log_with_floor(df['reer_combined'], 1)

    # Order rows by country and year so diff/rolling/shift act along time.
    # The index is reset here so the frame ends up with the same positional
    # index the earlier merge-based implementation produced.
    df = df.sort_values(['iso3', 'year']).reset_index(drop=True)

    # Annual change in log REER (row-to-row difference within country)
    df['d_log_reer'] = df.groupby('iso3')['log_reer_combined'].diff()

    # 5-row moving average, requiring at least 3 non-missing values
    df['log_reer_ma5'] = df.groupby('iso3')['log_reer_combined'].transform(
        lambda x: x.rolling(5, min_periods=3).mean())

    # Demeaned: deviation from the country's own mean (this is not an HP
    # filter). The mean uses every row loaded, including years before the
    # final sample window.
    df['log_reer_demeaned'] = df.groupby('iso3')['log_reer_combined'].transform(
        lambda x: x - x.mean())

    n_reer = df['reer_combined'].notna().sum()
    n_countries_reer = df.loc[df['reer_combined'].notna(), 'iso3'].nunique()
    print(f"  Combined REER: {n_countries_reer} countries, {n_reer:,} obs")
    return df


def _add_balassa_samuelson_controls(df: pd.DataFrame) -> pd.DataFrame:
    """Step [5]: add log productivity proxies for whichever source columns are present."""
    print("\n[5] Constructing Balassa-Samuelson controls ...")

    # GDP per capita (PPP) as productivity proxy
    if 'gdp_pc_ppp' in df.columns:
        df['log_gdp_pc'] = _log_with_floor(df['gdp_pc_ppp'], 100)
        print(f"  log_gdp_pc: {df['log_gdp_pc'].notna().sum():,} obs")

    # Output per worker, taken as-is from the base panel
    if 'output_per_worker' in df.columns:
        df['log_opw'] = _log_with_floor(df['output_per_worker'], 100)
        print(f"  log_opw: {df['log_opw'].notna().sum():,} obs")

    # Health expenditure as non-tradable share proxy (reported only, not transformed)
    if 'health_exp_gdp' in df.columns:
        print(f"  health_exp_gdp (non-tradable proxy): {df['health_exp_gdp'].notna().sum():,} obs")
    return df


def _add_ppp_deviation(df: pd.DataFrame) -> pd.DataFrame:
    """Step [6]: add deviation-from-100 and deviation-from-country-mean measures."""
    print("\n[6] PPP deviation measures ...")

    # If REER is indexed, deviation from 100 = PPP deviation
    df['reer_deviation'] = df['reer_combined'] - 100.0

    # Country-specific long-run mean, broadcast to every row of that country.
    df['log_reer_country_mean'] = (
        df.groupby('iso3')['log_reer_combined'].transform('mean'))
    # Same quantity as log_reer_demeaned; both column names are kept so
    # downstream consumers of either one keep working.
    df['reer_persistence'] = df['log_reer_combined'] - df['log_reer_country_mean']
    return df


def _add_regime_variables(df: pd.DataFrame) -> pd.DataFrame:
    """Step [7]: add OECD and eurozone dummies and, when possible, yearly income terciles."""
    print("\n[7] Regime variables ...")

    # OECD dummy (time-invariant)
    df['oecd'] = df['iso3'].isin(OECD_38).astype(int)

    # Eurozone dummy: 1 from the country's entry year onward
    df['eurozone'] = 0
    for iso3, join_year in EMU_MEMBERS.items():
        df.loc[(df['iso3'] == iso3) & (df['year'] >= join_year), 'eurozone'] = 1
    n_emu = df['eurozone'].sum()
    print(f"  Eurozone obs: {n_emu:,}")

    # Income groups: terciles of GDP per capita computed separately per year
    if 'gdp_pc_ppp' in df.columns:
        df['income_group'] = df.groupby('year')['gdp_pc_ppp'].transform(_income_tercile)
    return df


def _add_interaction_terms(df: pd.DataFrame) -> pd.DataFrame:
    """Step [8]: multiply each available Z variable by NFA, KAOPEN, trade openness and EMU."""
    print("\n[8] Interaction terms ...")

    for zv in _Z_VARS:
        if zv not in df.columns:
            continue
        # Z × NFA
        if 'nfa_gdp_lag' in df.columns:
            df[f'{zv}_x_nfa'] = df[zv] * df['nfa_gdp_lag']
        # Z × KAOPEN
        if 'kaopen' in df.columns:
            df[f'{zv}_x_kaopen'] = df[zv] * df['kaopen']
        # Z × trade openness
        if 'trade_openness' in df.columns:
            df[f'{zv}_x_trade'] = df[zv] * df['trade_openness']
        # Z × eurozone (the dummy is always created in step [7])
        df[f'{zv}_x_emu'] = df[zv] * df['eurozone']
    return df


def _add_demographic_lags(df: pd.DataFrame) -> pd.DataFrame:
    """Step [9]: add 5-row lags and first differences of Z, plus a 20-row lead of old_dep."""
    print("\n[9] Lagged/differenced demographics ...")

    # Shifts are by rows within country; they equal year offsets only if each
    # country has one row per consecutive year.
    for zv in _Z_VARS:
        if zv in df.columns:
            df[f'{zv}_lag5'] = df.groupby('iso3')[zv].shift(5)
            df[f'd_{zv}'] = df.groupby('iso3')[zv].diff()

    # Predetermined demographics: old_dep taken 20 rows ahead.
    # NOTE(review): rows after the last sample year were dropped at load, so
    # this lead is NaN for each country's final 20 rows. If full_panel.csv
    # holds later (projected) years, confirm that truncation is intended.
    if 'old_dep' in df.columns:
        df['oadr_plus20'] = df.groupby('iso3')['old_dep'].shift(-20)
    return df


def _restrict_and_report(df: pd.DataFrame):
    """Step [10]: keep the sample window and print coverage counts.

    Returns the restricted frame together with the number of countries in it
    and the number of those countries that have any combined-REER value.
    """
    print("\n[10] Panel summary ...")
    df = df[(df['year'] >= _FIRST_YEAR) & (df['year'] <= _LAST_YEAR)].copy()

    has_reer = df['reer_combined'].notna()
    n_total = len(df)
    n_countries = df['iso3'].nunique()
    n_with_reer = df.loc[has_reer, 'iso3'].nunique()
    n_oecd_reer = df.loc[has_reer & (df['oecd'] == 1), 'iso3'].nunique()
    n_nooecd_reer = df.loc[has_reer & (df['oecd'] == 0), 'iso3'].nunique()

    print(f"  Total panel: {n_countries} countries, {n_total:,} obs")
    print(f"  Countries with REER: {n_with_reer} (OECD: {n_oecd_reer}, non-OECD: {n_nooecd_reer})")
    print(f"  REER obs: {has_reer.sum():,}")
    return df, n_countries, n_with_reer


def _write_summary_table(df: pd.DataFrame, n_countries: int, n_with_reer: int) -> None:
    """Step [11]: write N/mean/SD/min/max of the key variables as a Markdown table."""
    print("\n[11] Building summary statistics ...")
    key_vars = ['reer_combined', 'log_reer_combined', 'd_log_reer',
                'reer_deviation', 'log_reer_demeaned',
                'Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep',
                'log_gdp_pc', 'log_opw', 'trade_openness',
                'fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'kaopen',
                'health_exp_gdp']
    key_vars = [v for v in key_vars if v in df.columns]

    md = ["# Summary Statistics — Real Exchange Rate Panel\n"]
    md.append("| Variable | N | Mean | SD | Min | Max |")
    md.append("|---|---|---|---|---|---|")
    for v in key_vars:
        s = df[v].dropna()
        if s.empty:
            # Column exists but has no data in the sample window: omit the row.
            continue
        md.append(f"| {v} | {len(s):,} | {s.mean():.3f} | {s.std():.3f} "
                  f"| {s.min():.3f} | {s.max():.3f} |")

    md.append(f"\n*Panel: {n_countries} countries, {df['year'].min()}-{df['year'].max()}. "
              f"REER coverage: {n_with_reer} countries.*")

    out = TABLES_DIR / "summary_statistics.md"
    # The table contains non-ASCII characters, so the encoding is pinned
    # instead of depending on the platform's locale default.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")


def _save_panel(df: pd.DataFrame) -> None:
    """Step [12]: write the assembled panel to rer_panel.csv without the index."""
    print("\n[12] Saving rer_panel.csv ...")
    out = PROCESSED_DIR / "rer_panel.csv"
    df.to_csv(out, index=False)
    print(f"  Saved: {out}")
    print(f"  Shape: {df.shape[0]:,} obs x {df.shape[1]} columns")


def main():
    """Assemble the real-exchange-rate panel and write it to disk.

    Runs twelve numbered stages in order: load the base panel, merge BIS and
    IMF REER data where the files are available, derive REER transformations,
    Balassa-Samuelson controls, PPP-deviation measures, regime dummies,
    interaction terms and demographic lags, restrict to the sample window,
    then write a Markdown summary table and the panel CSV. Progress is
    printed for every stage.

    Returns:
        pandas.DataFrame: the restricted panel exactly as written to
        ``rer_panel.csv``.

    Raises:
        FileNotFoundError: if ``full_panel.csv`` is missing (the two REER
            inputs are optional and are skipped with a message instead).
    """
    print("=" * 70)
    print("PHASE 1: Data Assembly — Real Exchange Rate Panel")
    print("=" * 70)

    df = _load_base_panel()
    df = _merge_bis_reer(df)
    df = _merge_imf_reer(df)
    df = _add_reer_variables(df)
    df = _add_balassa_samuelson_controls(df)
    df = _add_ppp_deviation(df)
    df = _add_regime_variables(df)
    df = _add_interaction_terms(df)
    df = _add_demographic_lags(df)
    df, n_countries, n_with_reer = _restrict_and_report(df)
    _write_summary_table(df, n_countries, n_with_reer)
    _save_panel(df)

    print("\n" + "=" * 70)
    print("Phase 1 complete.")
    print("=" * 70)

    return df


# Script entry point: run the full assembly when executed directly and bind
# the returned panel to a module-level `df`.
if __name__ == "__main__":
    df = main()
